#include "nemu.h"
#include "device/mmio.h"
#include "memory/mmu.h"

#define PMEM_SIZE (128 * 1024 * 1024)

#define PTE_ADDR(pte)   ((uint32_t)(pte)&-0xfff)
#define PDX(va)   (((uint32_t)(va)>>22)&0x3ff)
#define PTX(va)   (((uint32_t)(va)>>12)&0x3ff)
#define OFF(va)   ((uint32_t)(va)&0xfff)


#define pmem_rw(addr, type) *(type *)({\
    Assert(addr < PMEM_SIZE, "physical address(0x%08x) is out of bound", addr); \
    guest_to_host(addr); \
    })

uint8_t pmem[PMEM_SIZE];
paddr_t page_mapping(vaddr_t addr, bool is_write);
/* Memory accessing interfaces */

// uint32_t paddr_read(paddr_t addr, int len) {
//   return pmem_rw(addr, uint32_t) & (~0u >> ((4 - len) << 3));
// }

// void paddr_write(paddr_t addr, int len, uint32_t data) {
//   memcpy(guest_to_host(addr), &data, len);
// }

uint32_t paddr_read(paddr_t addr, int len) {
  int r=is_mmio(addr);
  if(r==-1)
    return pmem_rw(addr, uint32_t) & (~0u >> ((4-len) << 3));
  else
    return mmio_read(addr,len,r);
}

void paddr_write(paddr_t addr, int len, uint32_t data) {
  int r=is_mmio(addr);
  if(r==-1)
    memcpy(guest_to_host(addr), &data, len);
  else
    mmio_write(addr,len,data,r);
}

uint32_t vaddr_read(vaddr_t addr, int len) {
  //return paddr_read(addr, len);
  if(PTE_ADDR(addr) != PTE_ADDR(addr + len - 1)) {
    //Assert(0, "[Error] Data across 2 pages: addr = 0x%x, len = %d. \n", addr, len);
    int len1 = 0x1000 - OFF(addr);
    int len2 = len - len1;
    paddr_t paddr1 = page_mapping(addr, false);
    paddr_t paddr2 = page_mapping(addr + len1, false);
    uint32_t low = paddr_read(paddr1, len1);
    uint32_t high = paddr_read(paddr2, len2);
    uint32_t res = high << (len1 * 8) | low;
    return res;
  }
  else {
    paddr_t paddr = page_mapping(addr, false);
    return paddr_read(paddr, len);
  }
}

void vaddr_write(vaddr_t addr, int len, uint32_t data) {
  //paddr_write(addr, len, data);
  if(PTE_ADDR(addr) != PTE_ADDR(addr + len - 1)) {
    //Assert(0, "[Error] Data across 2 pages: addr = 0x%x, len = %d. \n", addr, len);
    int len1 = 0x1000 - OFF(addr);
    int len2 = len - len1;
    paddr_t paddr1 = page_mapping(addr, true);
    paddr_t paddr2 = page_mapping(addr + len1, true);
    uint32_t low = data & (~0u >> ((4 - len1) << 3));
    uint32_t high = data >> ((4 - len2) << 3);
    paddr_write(paddr1, len1, low);
    paddr_write(paddr2, len2, high);
    return;
  }
  else {
    paddr_t paddr = page_mapping(addr, true);
    paddr_write(paddr, len, data);
  }
}
paddr_t page_mapping(vaddr_t addr, bool is_write) {
  CR0 cr0 = (CR0)cpu.CR0;
  if(cr0.paging && cr0.protect_enable) { // paging && pretection
    CR3 cr3 = (CR3)cpu.CR3;
    
    // PDE
    PDE* pgdirs = (PDE*)PTE_ADDR(cr3.val);
    PDE pde = (PDE)paddr_read((uint32_t)(pgdirs + PDX(addr)), 4);
    Assert(pde.present, "addr: 0x%x", addr);

    // PTE
    PTE* ptab = (PTE*)PTE_ADDR(pde.val);
    PTE pte = (PTE)paddr_read((uint32_t)(ptab + PTX(addr)), 4);
    Assert(pte.present, "addr: 0x%x", addr);

    // accessed and dirty
    pde.accessed = 1;
    pte.accessed = 1;
    if(is_write) {
      pte.dirty = 1;
    }
    paddr_t paddr = PTE_ADDR(pte.val) | OFF(addr);
#ifdef DEBUG
    //printf("va: 0x%x    pa: 0x%x\n", addr, paddr);
#endif 
    return paddr;
  }
  // No paging
  return addr;
}